{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EvD_fmQGKwm7"
      },
      "source": [
        "# 7️⃣ Training Complex Adapter Combinations\n",
        "\n",
        "In this notebook, we explore how to easily configure complex combinations of different adapter methods with `ConfigUnion`. We show how to re-build the adapter setup used in Mao et al., 2022](https://arxiv.org/pdf/2110.07577.pdf) (UniPELT).\n",
        "For a basic introduction into the training setup with _Adapters_, please first refer to [the introductory training notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb).\n",
        "\n",
        "As training task, we chose abstractive summarization on the **XSum** dataset ([Narayan et al., 2018](https://arxiv.org/pdf/1808.08745.pdf)). As base model, we select **T5**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zcuSyuzjKzoW"
      },
      "source": [
        "## Installation\n",
        "\n",
        "First, let's install the required libraries:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3P8cRBAXK1fW",
        "outputId": "1f8b27f3-caa2-471d-86d2-287a054c90b6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install -qq adapters datasets evaluate nltk accelerate rouge_score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjiwZ8a0K7mJ"
      },
      "source": [
        "## Dataset Preprocessing\n",
        "\n",
        "Before we start to train our adapter, we first prepare the training data. The XSum dataset can be loaded via HuggingFace `datasets`.\n",
        "\n",
        "**Note:** To keep training time short in this notebook, we only load a small subset of the full dataset. For good results, make sure to train on the full dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 291,
          "referenced_widgets": [
            "799ddb94fff34164928166a3b9c1f25f",
            "21603eac1b8e46a0b6c29fabd51a3a1a",
            "be1d16ae210e4fb9aaec770966e37026",
            "da468c563ea94c519f1a4248a140c2ab",
            "4b6de7888ca34e86b7fc53af6c01f997",
            "56061c37acae4359b16114d1538c6edf",
            "09c631e83d0e4ce3bf53489dda4aaed0",
            "258fb1eaaf7c4eafbe5f27c15b91ac96",
            "f120e0ca40904eb685d71c192c46d085",
            "16f67d04ce674e3aa199d6c3f4ddeeb3",
            "ac2d36ffcfc3449c9c58e227ac257fe3",
            "7e2da1d810d740e6a23079a795de3e5e",
            "154457d60d0c454093b6c1e8803c5698",
            "46348706553b487b94c8185a780f22cf",
            "bc66d3376ba445549cb3555bdf0979f7",
            "a395ba5eec914e4e96d2e4fb80c94125",
            "0c7137af2b02449ba4f1845e23223261",
            "8cefcd5d608a4b9c9cf1e58a7a53fef5",
            "143332052e7941ee92218479164d46bd",
            "307b00fdf879406ab84903186619b328",
            "b7e6c12f17c045a2a6da3473feb3a6c1",
            "61126d24d0fc4833b255cf8f69443454",
            "73d1cd4c73ea49039c1d9ce28d211777",
            "e361def49aae4a50868191d662fb7680",
            "a5f49a4ef173413b8683b1fddd2ee230",
            "7f48995a17e94f54b251011eb38a18c6",
            "2b2e11219a7b40feb18817b8af5296e4",
            "0f832da6510f4df694a9d76b36879e68",
            "952546b7838942ae995dd36116a0a804",
            "f2ea54c688534830ab0f1bd393a8e228",
            "c860a847ab5e42a7835596cf83a1ac5f",
            "4732e7d36fa047baa961f40e0eca9a44",
            "383502ba723e4f0dbb3cd714e6994a28",
            "6fa3dad9dad1494e857cb69a063a9d85",
            "ae71c5392e364d67b725fb4d7215c912",
            "5bb84c5a5b1c48458eed1f949c873caa",
            "4af51b0bb25a4cf499c332ff0a8e6b75",
            "78bf192c0c0b4a2db0f819a56b790284",
            "0b7db13383f04f2bbf4746f754606e7a",
            "9262cb4684594fc3ba19da63f70e0e72",
            "26dea2fec67848368a84c74ef2489bd2",
            "cd30a514bc9d4568a03fa79d9d06295b",
            "0d12ffaa0f65473ba30fac5f380e5631",
            "8f7c2851a8af4657aa8cb4adb4d7be2a",
            "59dd13c441ce406b96147de93e2e4f85",
            "f407d6db93d04200a4393773a6c59619",
            "0aa296e2ef1b4255bdf9b139c4d1f056",
            "f68c4941385141efa44e5eb72311de7c",
            "9e2abe8adaa64590b63049c3b2788b20",
            "43901e3784cf4a28a31220355f1c0d2a",
            "451c29326ed841538fd502e9b71513c7",
            "bfdbea35fa32485782507eb205a2826c",
            "93e8215cffc34c1cabeeeaca5848e6b8",
            "19a026a043ab4c3ca70e83268f6486d5",
            "dc41ed508fd141c994f6f3c119201c7e",
            "000408980ba541df83a1bb54a38b791a",
            "0bd32de117a246419914e3c59965f402",
            "95036eb6c828446bad1aee5dd17258ba",
            "7bde20d36afe4ff494a7893def3c8de8",
            "9d2abdbccd5f491c8148a4e9e9fe811e",
            "66e6cbf4dc5748da89c5056761174998",
            "63b3b86c7eb745d09f02b311f4fa80cd",
            "558ca25096f34d738b6f3fcb4c942916",
            "2b0307168ba049d8abeafb1bf63db7d3",
            "3e7dbe25791f49a4a2e22f818a93d3dc",
            "4df0bfb480cf48bc89b02462114692cc",
            "155a28d0e35b41048302227e762f48a0",
            "535d678a3f4d48cbadb7f5515bc785b4",
            "177f01b58b63442884c6c9c295d16075",
            "bb7afe8dc54540588ae1ef1d3c33f011",
            "f47b5fc8249145d98d376ae9847ed216",
            "704cb077d15646b9bdead4da478f022c",
            "844c628f795d4f0aa19b389f086dd677",
            "3f4477ba473f440ca3fc5110a533db3d",
            "fbe445ba51d54c9eb4f3731918c42749",
            "d1ab22f684c6496795498dea96456944",
            "c6cf3b270f664262b01642c34d83aae0",
            "2326ea1997844aec8f647dd0bd76949a",
            "f4a45be5c2174588a1fdf94db78a964a",
            "25d44943ec7240f49d956754902f1577",
            "232716d4002243b79fcd17107810af9d",
            "f909ec0ef3414bbda0dbaadb5e26e775",
            "fdb6e0bf0c4b440d81087050aac2d9b3",
            "1a781022a9b448328ab70a7c17c0ae89",
            "37c807db4d134069871b8c398650dd4f",
            "85deadbf4aaa445bac59b09309654c4f",
            "0cc2cf4417c04091bcbd72d6fb92fe1c",
            "11ac8002832c46f3847e2582e85a9a08"
          ]
        },
        "id": "LCLS9RYeK83s",
        "outputId": "f28e27b1-a459-47e1-e1cf-3f2283e893bb"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "train_dataset = load_dataset(\"xsum\", split=\"train[:5000]\")\n",
        "val_dataset = load_dataset(\"xsum\", split=\"validation[:500]\")\n",
        "train_dataset.num_rows"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8DnDtZHpK-_t"
      },
      "source": [
        "Every dataset sample has an input document and a short summary:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "huLjPAKHLA1g",
        "outputId": "b311d1f0-c40c-45b0-9149-03d5a544c3ae"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'document': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\\n\"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\\'re neglected or forgotten,\" she said.\\n\"That may not be true but it is perhaps my perspective over the last few days.\\n\"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?\"\\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\\nThe Labour Party\\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\\n\"I was quite taken aback by the amount of damage that has been done,\" he said.\\n\"Obviously it is heart-breaking for people who have been forced out of their 
homes and the impact on businesses.\"\\nHe said it was important that \"immediate steps\" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.',\n",
              " 'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.',\n",
              " 'id': '35232142'}"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dataset[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In this example, we use a `t5-small` model for faster training. Feel free to use a larger checkpoint for better results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "eeN7fa-xR8HJ"
      },
      "outputs": [],
      "source": [
        "model_id = \"t5-small\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sYD2HAvSLFwA"
      },
      "source": [
        "Now, we need to encode all dataset samples to valid inputs for our Transformer model. We load the tokenizer matching the model we want to train. Using `dataset.map()`, we can pass the full dataset through the tokenizer in batches:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 81,
          "referenced_widgets": [
            "55eaee128e444763af7ea515be45ec0b",
            "d63286c4af6b4e639d0bdd040e08834b",
            "e9428933942946f987065015a3c1f0e7",
            "9435e13da75742a1aa9dbda76848822c",
            "ada4e8599ffd4892baafe488c2230245",
            "cf6e028ad17e4f44910e85715ea55ac3",
            "07d59fe106324266b862a4258b40c555",
            "88e19c4a33be40c28280f219d5c3c515",
            "73c5cbeda45043199caed203c3164283",
            "1c79ce71b3d943f792e30b1ded57aecb",
            "b4edc69e40114db3a2b4b03e1b261e13",
            "cb9cedfacba14d5282b1478d55c2075a",
            "2d5a0f2e25bf4ca3b0a13f3af9bd668d",
            "1fe067e9e9d04ecb98e93785a6d6f4f0",
            "58307b2ad68b40bab1102da1f638edc5",
            "ad5a27f51c68434ea7d80f609a8bd654",
            "581a0acca27c4a9084ca32720b1c98ad",
            "20627f86d26f4dfabebde438d1a14d41",
            "37c9ef85427e436b9be728f4862082c9",
            "5304bffd1a354a4a8d15d4255bbf850d",
            "ab6a903089f24c908a8be45f0da8d857",
            "2087c4b4ae994af190399116f93a7437"
          ]
        },
        "id": "oDHhicggLHW7",
        "outputId": "255f313c-89af-4f8e-bf4a-b4a19d88f536"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "prefix = \"summarize: \"\n",
        "max_input_length = 1024\n",
        "max_target_length = 128\n",
        "\n",
        "def encode_batch(examples):\n",
        "  \"\"\"Encodes a batch of input data using the model tokenizer.\"\"\"\n",
        "  inputs = [prefix + doc for doc in examples[\"document\"]]\n",
        "  model_inputs = tokenizer(inputs, max_length=None, padding=\"max_length\", truncation=True)\n",
        "\n",
        "  # Setup the tokenizer for targets\n",
        "  labels = tokenizer(text_target=examples[\"summary\"], max_length=max_target_length, padding=\"max_length\", truncation=True)\n",
        "\n",
        "  labels[\"input_ids\"] = [\n",
        "      [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels[\"input_ids\"]\n",
        "  ]\n",
        "\n",
        "  model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
        "  return model_inputs\n",
        "\n",
        "# Encode the input data\n",
        "train_dataset = train_dataset.map(encode_batch, batched=True)\n",
        "val_dataset = val_dataset.map(encode_batch, batched=True)\n",
        "# Transform to pytorch tensors and only output the required columns\n",
        "train_dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "val_dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
      ]
    },
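    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A note on the `-100` label values used in `encode_batch()`: PyTorch's cross-entropy loss skips targets equal to `-100` by default, so padded label positions do not contribute to the training loss. The cell below is a minimal, library-free sketch of that masking step on toy data. The helper `mask_padding()` and the toy token IDs are made up for illustration; the only assumption taken from T5 is that its pad token ID is `0`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def mask_padding(label_ids, pad_token_id, ignore_index=-100):\n",
        "    \"\"\"Replace padding token IDs with the index that the loss ignores.\"\"\"\n",
        "    return [\n",
        "        [tok if tok != pad_token_id else ignore_index for tok in seq]\n",
        "        for seq in label_ids\n",
        "    ]\n",
        "\n",
        "# Toy label sequences (hypothetical token IDs), padded with 0 to length 5\n",
        "toy_labels = [[37, 1782, 1, 0, 0], [8, 1, 0, 0, 0]]\n",
        "masked = mask_padding(toy_labels, pad_token_id=0)\n",
        "print(masked)"
      ]
    },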
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bwl-pv89LJfa"
      },
      "source": [
        "Now we're ready to train our model..."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XgykXOBHLLFK"
      },
      "source": [
        "## Training\n",
        "\n",
        "We use a pre-trained T5 model checkpoint from the Hugging Face Hub. We load it with [`AutoAdapterModel`](https://docs.adapterhub.ml/classes/models/auto.html), which comes built-in with all adapter functionality. [Learn more](https://docs.adapterhub.ml/prediction_heads.html#adaptermodel-classes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 286,
      "metadata": {
        "id": "lENbN034LMOk"
      },
      "outputs": [],
      "source": [
        "from adapters import AutoAdapterModel\n",
        "\n",
        "model = AutoAdapterModel.from_pretrained(model_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Now we're ready to build our adapter setup!**\n",
        "\n",
        "The UniPELT framework ([Mao et al., 2022](https://arxiv.org/pdf/2110.07577.pdf)) presented one approach of combining multiple types of adapter blocks in a single, combined setup.\n",
        "Visualized, UniPELT looks roughly like this:\n",
        "\n",
        "<div>\n",
        "<img src=\"https://docs.adapterhub.ml/_images/unipelt.png\" width=\"250\"/>\n",
        "</div>\n",
        "\n",
        "We can see that UniPELT is built from three well-known single adapter methods: 1) LoRA, 2) Prefix Tuning and 3) Sequential Bottleneck.\n",
        "\n",
        "_Adapters_ provides an easy way to flexibly build these composed configurations: [`ConfigUnion`](https://docs.adapterhub.ml/classes/adapter_config.html#adapters.ConfigUnion). `ConfigUnion` basically acts as a container holding multiple child adapter configs. [Learn more](https://docs.adapterhub.ml/method_combinations.html).\n",
        "\n",
        "With `ConfigUnion`, we can define UniPELT as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from adapters import ConfigUnion, PrefixTuningConfig, SeqBnConfig, LoRAConfig\n",
        "\n",
        "config = ConfigUnion(\n",
        "    LoRAConfig(r=8, use_gating=True),\n",
        "    PrefixTuningConfig(prefix_length=10, use_gating=True),\n",
        "    SeqBnConfig(reduction_factor=16, use_gating=True),\n",
        ")"
      ]
    },
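    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for the numbers in this configuration, the following sketch works through the arithmetic for a single layer. It assumes a hidden size of 512 (the `d_model` of `t5-small`), that the bottleneck size is the hidden size divided by `reduction_factor`, and that LoRA adds two rank-`r` matrices per adapted weight matrix. The function names are hypothetical, and the counts ignore gating and any other extra parameters, so treat them as rough estimates and not as the exact numbers _Adapters_ reports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "hidden_size = 512  # assumption: d_model of t5-small\n",
        "\n",
        "def bottleneck_params(hidden_size, reduction_factor):\n",
        "    \"\"\"Down- and up-projection (weights + biases) of one bottleneck adapter.\"\"\"\n",
        "    bottleneck_dim = hidden_size // reduction_factor\n",
        "    down = hidden_size * bottleneck_dim + bottleneck_dim\n",
        "    up = bottleneck_dim * hidden_size + hidden_size\n",
        "    return bottleneck_dim, down + up\n",
        "\n",
        "def lora_params(d_in, d_out, r):\n",
        "    \"\"\"Two low-rank matrices, A (r x d_in) and B (d_out x r), for one weight matrix.\"\"\"\n",
        "    return r * d_in + d_out * r\n",
        "\n",
        "bottleneck_dim, bn_count = bottleneck_params(hidden_size, reduction_factor=16)\n",
        "lora_count = lora_params(hidden_size, hidden_size, r=8)\n",
        "print(f\"bottleneck dim: {bottleneck_dim}, bottleneck params: {bn_count}\")\n",
        "print(f\"LoRA params per {hidden_size}x{hidden_size} matrix: {lora_count}\")"
      ]
    },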
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QY4Xiy9iLOVO"
      },
      "source": [
        "\n",
        "We now add a new adapter to our model by calling `add_adapter()`. We pass the name (`\"xsum\"`) and the adapter configuration we defined using `ConfigUnion`.\n",
        "\n",
        "Next, we add a seq2seq head. It's convenient to give the prediction head the same name as the adapter. This allows us to activate both together in the next step. The `train_adapter()` method does two things:\n",
        "\n",
        "1. It freezes all weights of the pre-trained model, so only the adapter weights are updated during training.\n",
        "2. It activates the adapter and the prediction head such that both are used in every forward pass."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 288,
      "metadata": {
        "id": "saGtTGogLPVf"
      },
      "outputs": [],
      "source": [
        "model.add_adapter(\"xsum\", config=config)\n",
        "\n",
        "# Add a matching classification head\n",
        "model.add_seq2seq_lm_head(\"xsum\")\n",
        "\n",
        "# Activate the adapter\n",
        "model.train_adapter(\"xsum\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zyLopENcLSEq"
      },
      "source": [
        "For training an adapter, we make use of the `Seq2SeqAdapterTrainer` class built-in into _Adapters_. This class is largely identical to _Transformer_'s `Seq2SeqTrainer`, with some helpful tweaks e.g. for checkpointing only adapter weights.\n",
        "\n",
        "We configure the training process using a `Seq2SeqTrainingArguments` object and define a method that will calculate the evaluation accuracy in the end. We pass both, together with the training and validation split of our dataset, to the trainer instance.\n",
        "\n",
        "**Note the differences in hyperparameters compared to full fine-tuning.** Adapter training usually requires a few more training epochs than full fine-tuning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 289,
      "metadata": {
        "id": "QoTEQbhQLTGB"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer\n",
        "from adapters import Seq2SeqAdapterTrainer\n",
        "\n",
        "training_args = Seq2SeqTrainingArguments(\n",
        "    learning_rate=5e-5,\n",
        "    num_train_epochs=6,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    evaluation_strategy = \"epoch\",\n",
        "    logging_steps=200,\n",
        "    output_dir=\"./training_output\",\n",
        "    overwrite_output_dir=True,\n",
        "    predict_with_generate=True,\n",
        "    fp16=True,\n",
        "    remove_unused_columns=False,\n",
        "    label_names=[\"labels\"]\n",
        ")"
      ]
    },
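    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on these settings, the number of optimizer steps can be worked out by hand. The sketch below assumes a single device, no gradient accumulation, and that the last incomplete batch of each epoch is kept. It uses the 5,000 training samples loaded earlier; the variable names are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "num_train_samples = 5000  # size of the train[:5000] subset\n",
        "batch_size = 16           # per_device_train_batch_size\n",
        "num_epochs = 6            # num_train_epochs\n",
        "\n",
        "# The last batch of an epoch may be smaller, hence the ceiling\n",
        "steps_per_epoch = math.ceil(num_train_samples / batch_size)\n",
        "total_steps = steps_per_epoch * num_epochs\n",
        "print(steps_per_epoch, total_steps)  # 313 1878"
      ]
    },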
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Some additional logic for computing metrics:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 290,
      "metadata": {
        "id": "4f3B5pKiR8HL"
      },
      "outputs": [],
      "source": [
        "# This is copied & adapted from https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb\n",
        "from evaluate import load\n",
        "import nltk\n",
        "\n",
        "nltk.download(\"punkt\")\n",
        "\n",
        "metric = load(\"rouge\")\n",
        "\n",
        "def compute_metrics(eval_pred):\n",
        "    predictions, labels = eval_pred\n",
        "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
        "    # Replace -100 in the labels as we can't decode them.\n",
        "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
        "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
        "\n",
        "    # Rouge expects a newline after each sentence\n",
        "    decoded_preds = [\"\\n\".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
        "    decoded_labels = [\"\\n\".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]\n",
        "\n",
        "    # Note that other metrics may not have a `use_aggregator` parameter\n",
        "    # and thus will return a list, computing a metric for each sentence.\n",
        "    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)\n",
        "    # Extract a few results\n",
        "    result = {key: value * 100 for key, value in result.items()}\n",
        "\n",
        "    # Add mean generated length\n",
        "    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]\n",
        "    result[\"gen_len\"] = np.mean(prediction_lens)\n",
        "\n",
        "    return {k: round(v, 4) for k, v in result.items()}"
      ]
    },
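    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "ROUGE-1, one of the scores returned by `compute_metrics()`, measures unigram overlap between a generated summary and its reference. The following is a deliberately simplified sketch of the ROUGE-1 F1 score (whitespace tokenization, lowercasing, no stemming), meant only to illustrate the idea. The function `rouge1_f1()` is hypothetical and is not how the `rouge` metric from `evaluate` is implemented, so its numbers can differ from the real metric."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "def rouge1_f1(prediction, reference):\n",
        "    \"\"\"Simplified ROUGE-1 F1: clipped unigram overlap on whitespace tokens.\"\"\"\n",
        "    pred_counts = Counter(prediction.lower().split())\n",
        "    ref_counts = Counter(reference.lower().split())\n",
        "    # Counter intersection keeps the minimum count of each shared token\n",
        "    overlap = sum((pred_counts & ref_counts).values())\n",
        "    if overlap == 0:\n",
        "        return 0.0\n",
        "    precision = overlap / sum(pred_counts.values())\n",
        "    recall = overlap / sum(ref_counts.values())\n",
        "    return 2 * precision * recall / (precision + recall)\n",
        "\n",
        "# Five of six tokens match in both directions\n",
        "score = rouge1_f1(\"the cat sat on the mat\", \"the cat is on the mat\")\n",
        "print(round(score * 100, 2))"
      ]
    },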
    {
      "cell_type": "code",
      "execution_count": 291,
      "metadata": {
        "id": "fkC6LpmMR8HL"
      },
      "outputs": [],
      "source": [
        "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
        "\n",
        "trainer = Seq2SeqAdapterTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    data_collator=data_collator,\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LtSsZNJsLVk4"
      },
      "source": [
        "Start the training 🚀"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 293,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 431
        },
        "id": "isGxmG9LLnhs",
        "outputId": "5a7819db-7a80-469b-f653-999292af4c6b"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1878' max='1878' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1878/1878 16:25, Epoch 6/6]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Rouge1</th>\n",
              "      <th>Rouge2</th>\n",
              "      <th>Rougel</th>\n",
              "      <th>Rougelsum</th>\n",
              "      <th>Gen Len</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>4.749900</td>\n",
              "      <td>3.782296</td>\n",
              "      <td>22.108700</td>\n",
              "      <td>3.826100</td>\n",
              "      <td>17.717600</td>\n",
              "      <td>17.700600</td>\n",
              "      <td>18.830000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>4.014100</td>\n",
              "      <td>3.535523</td>\n",
              "      <td>22.913200</td>\n",
              "      <td>4.204500</td>\n",
              "      <td>17.920600</td>\n",
              "      <td>17.902700</td>\n",
              "      <td>18.976000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>3.906600</td>\n",
              "      <td>3.440636</td>\n",
              "      <td>23.491500</td>\n",
              "      <td>4.494100</td>\n",
              "      <td>18.485900</td>\n",
              "      <td>18.515400</td>\n",
              "      <td>18.964000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>3.826700</td>\n",
              "      <td>3.392781</td>\n",
              "      <td>23.644400</td>\n",
              "      <td>4.461600</td>\n",
              "      <td>18.476800</td>\n",
              "      <td>18.502300</td>\n",
              "      <td>18.994000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>3.802500</td>\n",
              "      <td>3.366829</td>\n",
              "      <td>23.516100</td>\n",
              "      <td>4.397600</td>\n",
              "      <td>18.277100</td>\n",
              "      <td>18.291700</td>\n",
              "      <td>18.998000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>3.769400</td>\n",
              "      <td>3.359185</td>\n",
              "      <td>23.382400</td>\n",
              "      <td>4.438500</td>\n",
              "      <td>18.262200</td>\n",
              "      <td>18.275100</td>\n",
              "      <td>18.990000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1273: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1273: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1273: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=1878, training_loss=3.978136690279927, metrics={'train_runtime': 985.7264, 'train_samples_per_second': 30.434, 'train_steps_per_second': 1.905, 'total_flos': 5071442595840000.0, 'train_loss': 3.978136690279927, 'epoch': 6.0})"
            ]
          },
          "execution_count": 293,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0BLPys_0LuI5"
      },
      "source": [
        "Looks good! Let's evaluate our adapter on the validation split of the dataset to see how well it learned:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 294,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 221
        },
        "id": "hnct-N6dLvOd",
        "outputId": "ecfed495-36d5-4314-e1b4-0c95a16ccf82"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='32' max='32' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [32/32 00:36]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'eval_loss': 3.359184503555298,\n",
              " 'eval_rouge1': 23.3824,\n",
              " 'eval_rouge2': 4.4385,\n",
              " 'eval_rougeL': 18.2622,\n",
              " 'eval_rougeLsum': 18.2751,\n",
              " 'eval_gen_len': 18.99,\n",
              " 'eval_runtime': 38.6004,\n",
              " 'eval_samples_per_second': 12.953,\n",
              " 'eval_steps_per_second': 0.829,\n",
              " 'epoch': 6.0}"
            ]
          },
          "execution_count": 294,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XTtjRB_VNJbL"
      },
      "source": [
        "We can put our trained model into a _Transformers_ pipeline to be able to make new predictions conveniently:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 295,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x-BYe1jeNK98",
        "outputId": "7a776769-9a1d-4ad1-8292-92a7ed90e10a"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The model 'T5AdapterModel' is not supported for . Supported models are ['BartForConditionalGeneration', 'BigBirdPegasusForConditionalGeneration', 'BlenderbotForConditionalGeneration', 'BlenderbotSmallForConditionalGeneration', 'EncoderDecoderModel', 'FSMTForConditionalGeneration', 'GPTSanJapaneseForConditionalGeneration', 'LEDForConditionalGeneration', 'LongT5ForConditionalGeneration', 'M2M100ForConditionalGeneration', 'MarianMTModel', 'MBartForConditionalGeneration', 'MT5ForConditionalGeneration', 'MvpForConditionalGeneration', 'NllbMoeForConditionalGeneration', 'PegasusForConditionalGeneration', 'PegasusXForConditionalGeneration', 'PLBartForConditionalGeneration', 'ProphetNetForConditionalGeneration', 'SeamlessM4TForTextToText', 'SwitchTransformersForConditionalGeneration', 'T5ForConditionalGeneration', 'UMT5ForConditionalGeneration', 'XLMProphetNetForConditionalGeneration'].\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "[{'summary_text': 'The Disney Princess has a new film that starred in the film \"The Darkness of'}]"
            ]
          },
          "execution_count": 295,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from transformers import SummarizationPipeline\n",
        "\n",
        "summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer, device=training_args.device.index)\n",
        "\n",
        "summarizer(\"\"\"The film about a princess's mythical journey in ancient Polynesia took an estimated $81.1m (£65.3m) on its debut. That makes it the second-highest Thanksgiving debut of all time, behind Disney's Frozen, which took $93.6m (£75.3m) on its release in 2013. Some observers have said that Moana and its merchandise are appropriating Pacific Island culture. Disney withdrew a children's costume promoting the film after activists branded it \"brownface\", or mocking of their culture by stereotyping. The costume, a full-body suit with brown skin, traditional tattoos, grass skirt and bone necklace, represented the character Maui, considered a demi-god and ancestor by many Polynesians. Disney said it regretted any offence. JK Rowling's Fantastic Beasts and Where to Find Them fell to second on the US chart, taking $65.8m (£53m). Gossip surrounding Brad Pitt's marriage break-up failed to spark a huge amount of interest in his World War Two romance Allied, which also stars Marion Cotillard. It took $18m (£14.4m) over the long weekend, having cost $85m (£68.5m) to make, landing in fourth spot behind Doctor Strange. Kyle Davies, Paramount's head of domestic distribution, said the film appealed to \"older audiences\" but noted those \"don't storm the theatres [on] weekend one\". \"I think they're going to take their time,\" he added. Warren Beatty fared worse - his first film in 15 years, the 1950s Hollywood comedy Rules Don't Apply, took just $2.2m (£1.7m). The film is Beatty's first directed feature since 1998's Bulworth. Bad Santa 2, released 13 years after the original and again starring Billy Bob Thornton, did a little better, taking $9m (£7.3m). Follow us on Facebook, on Twitter @BBCNewsEnts, or on Instagram at bbcnewsents. If you have a story suggestion email entertainment.news@bbc.co.uk.\"\"\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ANVQrhIsNPOu"
      },
      "source": [
        "At last, we can also extract the adapter from our model and separately save it for later reuse. Note the size difference compared to a full model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 285,
      "metadata": {
        "id": "50FhEtBrNP_C"
      },
      "outputs": [],
      "source": [
        "model.save_adapter(\"./final_adapter\", \"xsum\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5sAJKaQeNREg"
      },
      "source": [
        "**Share your work!**\n",
        "\n",
        "The final step after successful training is to share our adapter with the world!\n",
        "_Adapters_ seamlessly integrates with the [Hugging Face Model Hub](https://huggingface.co/models), so you can publish your trained adapter with a single method call:\n",
        "\n",
        "**Important:** Make sure you're properly authenticated with your Hugging Face account before running this method. You can log in by running `huggingface-cli login` on your terminal."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DVGS6xvcR8HM"
      },
      "outputs": [],
      "source": [
        "model.push_adapter_to_hub(\n",
        "    \"my-awesome-adapter\",\n",
        "    \"xsum\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "78jqOTrLR8HN"
      },
      "source": [
        "This will create a repository _my-awesome-adapter_ under your username, generate a default adapter card as README.md and upload the adapter named `xsum` together with the adapter card to the new repository. [Learn more](https://docs.adapterhub.ml/huggingface_hub.html)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
